import numpy as np
import cv2

def preprocess(img, thresh_value=225):
    """Binarize a plate image and build a morphologically merged mask.

    The image is converted to grayscale, smoothed with a 3x3 Gaussian and
    thresholded at ``thresh_value``. The thresholded image is then dilated
    with a 9x7 rectangle, eroded with a 9x1 rectangle and dilated again with
    the 9x7 rectangle.

    Args:
        img: BGR colour image (3 channels) or an already-grayscale 2-D image.
        thresh_value: Gray level a pixel must exceed to become 255 in the
            binary image (default 225, the previously hard-coded value).

    Returns:
        ``(binary, dilation2)``: the plain thresholded image and the
        dilate/erode/dilate result.
    """
    # cvtColor would reject an already single-channel (2-D) image.
    gray = img if img.ndim == 2 else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # 3x3 Gaussian smoothing. Keyword arguments are required: the 4th
    # positional parameter of cv2.GaussianBlur is ``dst``, so passing
    # ``0, 0, cv2.BORDER_DEFAULT`` positionally would set sigmaY to
    # BORDER_DEFAULT (4) instead of selecting the border mode.
    gaussian = cv2.GaussianBlur(gray, (3, 3), sigmaX=0, sigmaY=0,
                                borderType=cv2.BORDER_DEFAULT)
    # Binarize the Gaussian-smoothed image (no median filter is applied):
    # pixels brighter than thresh_value become 255, all others 0.
    _, binary = cv2.threshold(gaussian, thresh_value, 255, cv2.THRESH_BINARY)
    # Structuring elements; getStructuringElement takes (width, height).
    element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 1))
    element2 = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 7))
    # Dilate once so outlines stand out.
    dilation = cv2.dilate(binary, element2, iterations=1)
    # Erode once with a horizontal line to remove fine detail.
    erosion = cv2.erode(dilation, element1, iterations=1)
    # Dilate again to re-merge the remaining regions.
    dilation2 = cv2.dilate(erosion, element2, iterations=1)
    return binary, dilation2

def pickpoint(img, area_ratio=0.0036):
    """Erase small blobs from a binary mask, then erode it horizontally.

    Every contour whose area is below ``area_ratio`` times the image area
    has its minimum-area (rotated) bounding rectangle filled with 0. The
    result is then eroded once with a 9x1 rectangular kernel.

    Args:
        img: Single-channel 8-bit binary mask (e.g. the second value returned
            by ``preprocess``). The caller's array is not modified.
        area_ratio: Fraction of the image area below which a contour is
            erased (default 0.0036, the previously hard-coded value).

    Returns:
        The cleaned, eroded mask as a new array.
    """
    height, width = img.shape[:2]
    edgemax = height * width * area_ratio
    # Work on a copy: drawContours paints in place (and findContours in older
    # OpenCV releases also modifies its input).
    img = img.copy()
    # findContours returns (image, contours, hierarchy) in OpenCV 3.x but
    # (contours, hierarchy) in 4.x; index -2 selects the contours on both.
    contours = cv2.findContours(img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for cnt in contours:
        if cv2.contourArea(cnt) < edgemax:
            # Four corner points of the rotated bounding box as integer
            # pixel coordinates (np.intp replaces the removed np.int0 alias).
            box = cv2.boxPoints(cv2.minAreaRect(cnt)).astype(np.intp)
            # Fill the box with 0 (background). On a single-channel mask
            # only the first colour component is used, so this matches the
            # earlier (0, 0, 255) tuple.
            cv2.drawContours(img, [box], -1, 0, thickness=-1)
    # One erosion with a 9x1 horizontal kernel.
    element1 = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 1))
    return cv2.erode(img, element1, iterations=1)


def find_end(start, arg, black, white, width, black_max, white_max, ratio=0.95):
    """Return the column where the character run beginning at ``start`` ends.

    Scans columns ``start + 1`` .. ``width - 2`` for the first one whose
    background-pixel count exceeds ``ratio`` times the maximum background
    count, i.e. a (nearly) empty gap column.

    Args:
        start: Column index where the current character run begins.
        arg: True when characters are light on a dark background (background
            counts come from ``black``); False for the inverse (``white``).
        black: Per-column counts of 0-valued pixels.
        white: Per-column counts of 255-valued pixels.
        width: Number of columns in the image.
        black_max: Maximum of ``black``.
        white_max: Maximum of ``white``.
        ratio: Fraction of the maximum background count a column must exceed
            to count as a gap (default 0.95).

    Returns:
        The first gap column, usable as an exclusive slice end. If no gap is
        found before the right edge, ``width - 1`` (never less than
        ``start + 1``), so a character touching the right border is kept
        instead of being collapsed to a one-column run.
    """
    background = black if arg else white
    limit = ratio * (black_max if arg else white_max)
    for m in range(start + 1, width - 1):
        if background[m] > limit:
            return m
    # No gap before the edge: the run extends over every scanned column.
    return max(start + 1, width - 1)

def _normalize_char(cj):
    """Pad, crop, smooth and re-binarize one character strip cut from a plate."""
    cjheight, cjwidth = cj.shape[:2]
    # Pad left and right by 20% of the strip width with black.
    pad = int(cjwidth * 0.2)
    cj = cv2.copyMakeBorder(cj, 0, 0, pad, pad, cv2.BORDER_CONSTANT, value=0)
    # Crop 25% off the top and bottom so the crop fits the training data.
    length = int(cjheight * 0.25)
    cj = cj[length:cjheight - length, :]
    # Smooth (keyword args: the 4th positional GaussianBlur parameter is dst).
    cj = cv2.GaussianBlur(cj, (3, 3), sigmaX=0, sigmaY=0,
                          borderType=cv2.BORDER_DEFAULT)
    # Mean smoothing with a 3-wide, 5-tall box.
    cj = cv2.blur(cj, (3, 5))
    # Histogram equalization.
    cj = cv2.equalizeHist(cj)
    # Binarize to boost brightness, then smooth again.
    _, cj = cv2.threshold(cj, 60, 255, cv2.THRESH_BINARY)
    cj = cv2.GaussianBlur(cj, (3, 3), sigmaX=0, sigmaY=0,
                          borderType=cv2.BORDER_DEFAULT)
    cj = cv2.blur(cj, (3, 5))
    return cj


def cut(thresh, binary):
    """Segment a plate into character images using vertical projections.

    Per-column counts of 255- and 0-valued pixels in ``thresh`` locate the
    character runs. A run starts at a column whose foreground count exceeds
    5% of the foreground maximum and ends at the column returned by
    ``find_end``. Each run wider than 5 columns is cut from ``binary``
    (skipping its first row) and post-processed by ``_normalize_char``.

    Args:
        thresh: Single-channel mask used for the column projections
            (e.g. the output of ``pickpoint``).
        binary: Single-channel 8-bit image the character crops are taken from
            (e.g. the first value returned by ``preprocess``).

    Returns:
        List of processed character images, ordered left to right.
    """
    height, width = thresh.shape[:2]
    print('height', height)
    print('width', width)
    # Count white (255) and black (0) pixels in every column in one
    # vectorized pass instead of a per-pixel Python loop.
    mask = np.asarray(thresh)
    white = (mask == 255).sum(axis=0).tolist()
    black = (mask == 0).sum(axis=0).tolist()
    white_max = max(white, default=0)
    black_max = max(black, default=0)
    # True: light characters on a dark background; False: dark on light.
    arg = black_max >= white_max
    # A column belongs to a character when its foreground count exceeds 5%
    # of the maximum (pairs with the 0.95 gap ratio in find_end; tunable).
    foreground = white if arg else black
    fg_limit = 0.05 * (white_max if arg else black_max)

    char = []
    n = 1
    while n < width - 2:
        n += 1
        if foreground[n] > fg_limit:
            start = n
            end = find_end(start, arg, black, white, width, black_max, white_max)
            n = end
            if end - start > 5:
                char.append(_normalize_char(binary[1:height, start:end]))
    return char

if __name__ == '__main__':
    import sys

    # Optional first command-line argument overrides the default image path.
    imagePath = sys.argv[1] if len(sys.argv) > 1 else './plate.png'
    img = cv2.imread(imagePath)
    # cv2.imread returns None instead of raising when the file is missing or
    # unreadable; fail with a clear message rather than deep inside cvtColor.
    if img is None:
        raise SystemExit('cannot read image: {}'.format(imagePath))
    cv2.imshow('img', img)
    binary, dilation2 = preprocess(img)
    thresh = pickpoint(dilation2)
    charlist = cut(thresh, binary)
    print('characters found:', len(charlist))
    cv2.waitKey(0)
    cv2.destroyAllWindows()